def ntksap(train_loader, networks):
    """Accumulate a perturbation-based saliency on the weight masks.

    For ``self.R`` passes over ``dataloader`` the model is re-initialized,
    fed Gaussian noise shaped like each batch, and compared against a deep
    copy whose masked parameters were perturbed with ``self.epsilon``-scaled
    Gaussian noise.  The summed squared output difference is backpropagated
    and the resulting mask gradients are stored, as absolute values, in
    ``self.scores`` keyed by ``id(parameter)``.

    NOTE(review): the body reads ``self``, ``model``, ``dataloader``,
    ``device``, ``generator`` and ``reset_BN``, none of which are parameters
    or defined inside this block, while ``train_loader`` and ``networks``
    are never used.  Presumably this was lifted from a pruner method --
    confirm these names resolve at module scope, or rewire them to the
    parameters.

    Returns:
        None.  Results are written to ``self.scores``.
    """

    def perturb(model_orig, model_copy):
        """Make ``model_copy`` a noisy twin of ``model_orig``.

        Masked parameters of the copy become the original's values plus
        ``self.epsilon``-scaled Gaussian noise; BatchNorm2d running
        statistics of the copy are pointed at the original's tensors
        (assigned by reference, not cloned).
        """
        with torch.no_grad():
            for (_, p_orig), (_, p_copy) in zip(
                generator.masked_parameters(model_orig),
                generator.masked_parameters(model_copy),
            ):
                p_copy.data = p_orig.data + self.epsilon * torch.randn_like(p_orig.data)

        for module, module_mod in zip(model_orig.modules(), model_copy.modules()):
            if isinstance(module, nn.BatchNorm2d):
                module_mod.running_mean = module.running_mean
                module_mod.running_var = module.running_var
                module_mod.num_batches_tracked = module.num_batches_tracked

    # Gradients are collected on the masks, not on the parameters.
    for m, p in self.masked_parameters:
        m.requires_grad = True
        p.requires_grad = False

    # Evaluation-mode twin of the model that will receive the perturbation.
    model_mod = copy.deepcopy(model)
    model_mod.eval()

    # Share the weight masks between both models and force BatchNorm2d
    # momentum to 1.0.  The original momentum of each layer is remembered so
    # it can be restored afterwards instead of being clobbered with a
    # hard-coded default.
    saved_momentum = []
    for module, module_mod in zip(model.modules(), model_mod.modules()):
        if hasattr(module, 'weight_mask'):
            module_mod.weight_mask = module.weight_mask
        if isinstance(module, nn.BatchNorm2d):
            saved_momentum.append((module, module.momentum))
            module.momentum = 1.0
            module_mod.momentum = 1.0

    for _ in range(self.R):
        for data, _ in dataloader:
            # Fresh random weights for every batch.
            if isinstance(model, nn.DataParallel):
                model.module._initialize_weights()
            else:
                model._initialize_weights()
            # Only the shape/dtype of the batch is used; the input is noise.
            noise = torch.randn_like(data).to(device)

            reset_BN(model)
            # Output is discarded.  Presumably run for its side effect on the
            # BatchNorm running statistics (momentum is 1.0) -- this assumes
            # the model is in training mode here; TODO confirm for the very
            # first iteration.
            with torch.no_grad():
                model(noise)

            model.eval()
            # Compute the true graph using eval mode
            output_orig = model(noise)
            perturb(model, model_mod)
            output_mod = model_mod(noise)
            jac_approx = (torch.norm(output_orig - output_mod, dim=-1) ** 2).sum()
            jac_approx.backward()
            model.train()

    for m, p in self.masked_parameters:
        # Absolute mask gradient, zeroed wherever the mask itself is zero.
        self.scores[id(p)] = torch.clone(m.grad * (m != 0)).detach().abs_()
        m.grad.data.zero_()
        m.requires_grad = False
        p.requires_grad = True

    # Restore the momentum each BatchNorm2d layer had before scoring.
    for module, momentum in saved_momentum:
        module.momentum = momentum

    del model_mod